import logging
import time
from typing import Callable, List

from django.db import connection
from django.db.backends.utils import CursorWrapper
from psycopg2.sql import SQL

from zerver.models import UserProfile

"""
NOTE!  Be careful modifying this library, as it is used
in a migration, and it needs to be valid for the state
of the database that is in place when the 0104_fix_unreads
migration runs.
"""

# Module-level logger for this library.  The level is pinned to WARNING,
# so the logger.info() calls below are dropped unless a caller lowers
# the level on the "zulip.fix_unreads" logger first.
logger = logging.getLogger("zulip.fix_unreads")
logger.setLevel(logging.WARNING)


def build_topic_mute_checker(
    cursor: CursorWrapper, user_profile: UserProfile
) -> Callable[[int, str], bool]:
    """
    Return a predicate ``is_muted(recipient_id, topic)`` for one user.

    This mirrors the function of the same name in
    zerver/lib/topic_mutes.py, but it talks to the database through
    a raw cursor instead of the ORM so that it can be called from
    migrations.

    Every zerver_usertopic row belonging to the user is loaded once,
    up front; the returned predicate then answers from memory.  Topic
    names are compared case-insensitively (both sides are lowercased).
    """
    query = SQL(
        """
        SELECT
            recipient_id,
            topic_name
        FROM
            zerver_usertopic
        WHERE
            user_profile_id = %s
    """
    )
    cursor.execute(query, [user_profile.id])

    # Normalize topic names once here so lookups only lowercase the probe.
    muted_keys = set()
    for recipient_id, topic_name in cursor.fetchall():
        muted_keys.add((recipient_id, topic_name.lower()))

    def is_muted(recipient_id: int, topic: str) -> bool:
        lookup_key = (recipient_id, topic.lower())
        return lookup_key in muted_keys

    return is_muted


def update_unread_flags(cursor: CursorWrapper, user_message_ids: List[int]) -> None:
    """
    Set bit 1 of ``flags`` on the given zerver_usermessage rows.

    Callers in this module select rows where ``(flags & 1) = 0`` as
    "unread", so setting the bit marks those rows as read.

    ``user_message_ids`` holds zerver_usermessage primary keys.  An
    empty list is a no-op: no query is sent to the database.
    """
    if not user_message_ids:
        # psycopg2 renders an empty tuple as "()", and "IN ()" is a
        # syntax error in PostgreSQL, so there is nothing valid to run.
        return

    query = SQL(
        """
        UPDATE zerver_usermessage
        SET flags = flags | 1
        WHERE id IN %(user_message_ids)s
    """
    )

    cursor.execute(query, {"user_message_ids": tuple(user_message_ids)})


def get_timing(message: str, f: Callable[[], None]) -> None:
    """
    Log ``message``, run ``f``, then log how long ``f`` took.

    The duration is measured with time.perf_counter() rather than
    time.time(), so a wall-clock adjustment (NTP step, manual change)
    while ``f`` runs cannot skew the reported value or make it
    negative.

    If ``f`` raises, the exception propagates and no elapsed time is
    logged.
    """
    start = time.perf_counter()
    logger.info(message)
    f()
    elapsed = time.perf_counter() - start
    logger.info("elapsed time: %.03f\n", elapsed)


def fix_unsubscribed(cursor: CursorWrapper, user_profile: UserProfile) -> None:
    """
    Mark as read the user's unread messages that were sent to
    recipients of type 2 (streams, per the log messages below) for
    which the user's subscription row is not active.

    The work happens in three steps, each wrapped in get_timing():

    1. collect the recipient ids of the user's inactive subscriptions;
    2. collect the ids of the user's zerver_usermessage rows for those
       recipients whose flags have bit 1 clear (i.e. unread);
    3. set bit 1 on those rows via update_unread_flags().

    The function returns early as soon as a step finds nothing.
    """
    inactive_recipient_ids: List[int] = []
    unread_user_message_ids: List[int] = []

    def find_recipients() -> None:
        sql = SQL(
            """
            SELECT
                zerver_subscription.recipient_id
            FROM
                zerver_subscription
            INNER JOIN zerver_recipient ON (
                zerver_recipient.id = zerver_subscription.recipient_id
            )
            WHERE (
                zerver_subscription.user_profile_id = %(user_profile_id)s AND
                zerver_recipient.type = 2 AND
                (NOT zerver_subscription.active)
            )
        """
        )
        cursor.execute(sql, {"user_profile_id": user_profile.id})
        inactive_recipient_ids.extend(row[0] for row in cursor.fetchall())
        logger.info(str(inactive_recipient_ids))

    def find() -> None:
        sql = SQL(
            """
            SELECT
                zerver_usermessage.id
            FROM
                zerver_usermessage
            INNER JOIN zerver_message ON (
                zerver_message.id = zerver_usermessage.message_id
            )
            WHERE (
                zerver_usermessage.user_profile_id = %(user_profile_id)s AND
                (zerver_usermessage.flags & 1) = 0 AND
                zerver_message.recipient_id in %(recipient_ids)s
            )
        """
        )
        params = {
            "user_profile_id": user_profile.id,
            "recipient_ids": tuple(inactive_recipient_ids),
        }
        cursor.execute(sql, params)
        unread_user_message_ids.extend(row[0] for row in cursor.fetchall())
        logger.info("rows found: %d", len(unread_user_message_ids))

    get_timing("get recipients", find_recipients)
    if len(inactive_recipient_ids) == 0:
        # No inactive stream subscriptions, so nothing can need fixing.
        return

    get_timing("finding unread messages for non-active streams", find)
    if len(unread_user_message_ids) == 0:
        return

    get_timing(
        "fixing unread messages for non-active streams",
        lambda: update_unread_flags(cursor, unread_user_message_ids),
    )


def fix(user_profile: UserProfile) -> None:
    """
    Entry point: run fix_unsubscribed() for ``user_profile`` using a
    cursor obtained from Django's ``connection``.

    The cursor is opened as a context manager, so it is closed when
    the fix finishes or raises.
    """
    logger.info("\n---\nFixing %s:", user_profile.id)
    with connection.cursor() as db_cursor:
        fix_unsubscribed(db_cursor, user_profile)
